{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "84f2925d-46dc-40b4-a49d-dfa7a89904c9",
   "metadata": {},
   "source": [
    "# L2: Multimodal Search"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5384702b",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b76146",
   "metadata": {},
   "source": [
    "* In this classroom, the libraries have already been installed for you.\n",
    "* To run this code on your own machine, install the following:\n",
    "```\n",
    "    !pip install -U weaviate-client\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e516888-e811-4afa-ae23-1703f3914d76",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "385275eb-b09d-43f5-8c34-48a0e356a38a",
   "metadata": {},
   "source": [
    "## Setup\n",
    "### Load environment variables and API keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8a90786-0a3f-41a7-8e43-47b3efe7a075",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "EMBEDDING_API_KEY = os.getenv(\"EMBEDDING_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8cf5927-fd56-4dd8-b0e4-b1b05eba3c19",
   "metadata": {},
   "source": [
    "## Connect to Weaviate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5078b460-4527-4c71-a296-1d8ba8a15470",
   "metadata": {
    "height": 251
   },
   "outputs": [],
   "source": [
    "import weaviate, os\n",
    "\n",
    "client = weaviate.connect_to_embedded(\n",
    "    version=\"1.24.4\",\n",
    "    environment_variables={\n",
    "        \"ENABLE_MODULES\": \"backup-filesystem,multi2vec-palm\",\n",
    "        \"BACKUP_FILESYSTEM_PATH\": \"/home/jovyan/work/backups\",\n",
    "    },\n",
    "    headers={\n",
    "        \"X-PALM-Api-Key\": EMBEDDING_API_KEY,\n",
    "    }\n",
    ")\n",
    "\n",
    "client.is_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42ac7645-7cdd-4178-9785-cc09af78d7fe",
   "metadata": {},
   "source": [
    "## Create the Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41a1d904-612d-49b0-8582-d0f52e847da3",
   "metadata": {
    "height": 302
   },
   "outputs": [],
   "source": [
    "from weaviate.classes.config import Configure\n",
    "\n",
    "# Delete the collection if it already exists, so this cell can be re-run safely\n",
    "if client.collections.exists(\"Animals\"):\n",
    "    client.collections.delete(\"Animals\")\n",
    "\n",
    "client.collections.create(\n",
    "    name=\"Animals\",\n",
    "    vectorizer_config=Configure.Vectorizer.multi2vec_palm(\n",
    "        image_fields=[\"image\"],\n",
    "        video_fields=[\"video\"],\n",
    "        project_id=\"semi-random-dev\",\n",
    "        location=\"us-central1\",\n",
    "        model_id=\"multimodalembedding@001\",\n",
    "        dimensions=1408,        \n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a96e7749-0d45-47f9-acdf-43a5dd5f3748",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2b6721d-67a9-4af8-88ef-a619ac371bf1",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "import base64\n",
    "\n",
    "# Helper function to convert a file to base64 representation\n",
    "def toBase64(path):\n",
    "    with open(path, 'rb') as file:\n",
    "        return base64.b64encode(file.read()).decode('utf-8')"
   ]
  },
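  {
   "cell_type": "markdown",
   "id": "b64-roundtrip-sketch",
   "metadata": {},
   "source": [
    "To see what `toBase64` returns, it can help to round-trip a small file: encode it, decode the resulting string, and compare the bytes. The sketch below is only an illustration; it repeats the helper so it runs on its own, and the temporary file is a stand-in for a real image.\n",
    "\n",
    "```python\n",
    "import base64\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "def toBase64(path):\n",
    "    # Same helper as above, repeated so this sketch runs on its own\n",
    "    with open(path, 'rb') as file:\n",
    "        return base64.b64encode(file.read()).decode('utf-8')\n",
    "\n",
    "# Write 300 arbitrary bytes to a temporary file (a stand-in for an image)\n",
    "raw = bytes(range(256)) + b'x' * 44\n",
    "with tempfile.NamedTemporaryFile(delete=False, suffix='.bin') as tmp:\n",
    "    tmp.write(raw)\n",
    "    tmp_path = tmp.name\n",
    "\n",
    "encoded = toBase64(tmp_path)\n",
    "os.remove(tmp_path)\n",
    "\n",
    "# Decoding the string gives back the original bytes\n",
    "assert base64.b64decode(encoded) == raw\n",
    "\n",
    "# Compare the input size in bytes with the encoded size in characters\n",
    "print(len(raw), len(encoded))\n",
    "```\n",
    "\n",
    "This prints `300 400`: base64 emits 4 characters for every 3 bytes, so the payload sent to Weaviate is about a third larger than the file on disk."
   ]
  },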
  {
   "cell_type": "markdown",
   "id": "5f636882-43a7-4b78-9d1f-c5c2819a3b6f",
   "metadata": {},
   "source": [
    "## Insert Images into Weaviate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "375e1b91-b3f9-4279-aaa7-4ae140d398f7",
   "metadata": {
    "height": 285
   },
   "outputs": [],
   "source": [
    "animals = client.collections.get(\"Animals\")\n",
    "\n",
    "# List and read the images from the same directory\n",
    "image_dir = \"./source/animal_image/\"\n",
    "source = os.listdir(image_dir)\n",
    "\n",
    "with animals.batch.rate_limit(requests_per_minute=100) as batch:\n",
    "    for name in source:\n",
    "        print(f\"Adding {name}\")\n",
    "\n",
    "        path = image_dir + name\n",
    "    \n",
    "        batch.add_object({\n",
    "            \"name\": name,            # name of the file\n",
    "            \"path\": path,            # path to the file to display result\n",
    "            \"image\": toBase64(path), # this gets vectorized - \"image\" was configured in vectorizer_config as the property holding images\n",
    "            \"mediaType\": \"image\",    # a label telling us how to display the resource \n",
    "        })"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76406454",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6ff; padding:15px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\"> 💻 &nbsp; <b>Access Utils File and Helper Functions:</b> To access the files for this notebook, 1) click on the <em>\"File\"</em> option on the top menu of the notebook and then 2) click on <em>\"Open\"</em>. For more help, please see the <em>\"Appendix - Tips and Help\"</em> Lesson.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43d6b5dd-fe14-4dd1-8f88-7cdb2c4df18d",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "# Check for failed objects\n",
    "if len(animals.batch.failed_objects) > 0:\n",
    "    print(f\"Failed to import {len(animals.batch.failed_objects)} objects\")\n",
    "    for failed in animals.batch.failed_objects:\n",
    "        print(f\"e.g. Failed to import object with error: {failed.message}\")\n",
    "else:\n",
    "    print(\"No errors\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d8f3bcd-8931-431c-becc-7a8529c9b4b0",
   "metadata": {},
   "source": [
    "## Insert Video Files into Weaviate\n",
    "> Note: the input video must be at least 4 seconds long."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c667d9e3-e48e-4aa3-b975-df61c3c7bdb3",
   "metadata": {
    "height": 268
   },
   "outputs": [],
   "source": [
    "animals = client.collections.get(\"Animals\")\n",
    "\n",
    "source = os.listdir(\"./source/video/\")\n",
    "\n",
    "for name in source:\n",
    "    print(f\"Adding {name}\")\n",
    "    path = \"./source/video/\" + name    \n",
    "\n",
    "    # insert videos one by one\n",
    "    animals.data.insert({\n",
    "        \"name\": name,\n",
    "        \"path\": path,\n",
    "        \"video\": toBase64(path),\n",
    "        \"mediaType\": \"video\"\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f6e7ae5-3e2d-410b-aca6-1151530a9329",
   "metadata": {
    "height": 132
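   },
   "outputs": [],
   "source": [
    "# Check for failed objects\n",
    "# Note: the videos above were added one by one with data.insert, not through a batch,\n",
    "# so this list only reflects the earlier image batch\n",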
    "if len(animals.batch.failed_objects) > 0:\n",
    "    print(f\"Failed to import {len(animals.batch.failed_objects)} objects\")\n",
    "    for failed in animals.batch.failed_objects:\n",
    "        print(f\"e.g. Failed to import object with error: {failed.message}\")\n",
    "else:\n",
    "    print(\"No errors\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0f28c8d-9227-4ae3-9461-6037b7ed3902",
   "metadata": {},
   "source": [
    "## Check count\n",
    "> Total count should be 15 (9x image + 6x video)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4750a4d8-9dfb-43b4-add5-564d00d3655e",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "agg = animals.aggregate.over_all(\n",
    "    group_by=\"mediaType\"\n",
    ")\n",
    "\n",
    "for group in agg.groups:\n",
    "    print(group)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc610dae-f996-4a25-9dc0-f77fba576144",
   "metadata": {},
   "source": [
    "## Build Multimodal Search\n",
    "### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87099e43-2d4e-40f4-a3dc-aaf2508bc4f0",
   "metadata": {
    "height": 268
   },
   "outputs": [],
   "source": [
    "# Helper functions to display results\n",
    "import json\n",
    "from IPython.display import Image, Video, display\n",
    "\n",
    "def json_print(data):\n",
    "    print(json.dumps(data, indent=2))\n",
    "\n",
    "def display_media(item):\n",
    "    # Show an image or a video, depending on the item's mediaType\n",
    "    path = item[\"path\"]\n",
    "\n",
    "    if item[\"mediaType\"] == \"image\":\n",
    "        display(Image(path, width=300))\n",
    "\n",
    "    elif item[\"mediaType\"] == \"video\":\n",
    "        display(Video(path, width=300))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cadd7bbb-7f97-46f0-9ac1-36020974352f",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "import base64, requests\n",
    "\n",
    "# Helper function - get base64 representation from an online image\n",
    "def url_to_base64(url):\n",
    "    image_response = requests.get(url, timeout=30)\n",
    "    image_response.raise_for_status()  # fail early on a bad URL instead of encoding an error page\n",
    "    content = image_response.content\n",
    "    return base64.b64encode(content).decode('utf-8')\n",
    "\n",
    "# Helper function - get base64 representation from a local file\n",
    "def file_to_base64(path):\n",
    "    with open(path, 'rb') as file:\n",
    "        return base64.b64encode(file.read()).decode('utf-8')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c090eb5-074d-43f4-ad79-dd8b70fbb348",
   "metadata": {},
   "source": [
    "## Text to Media Search"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44663baa",
   "metadata": {},
   "source": [
    "> Where the fun begins!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e2401bf-6dd3-41ec-87fd-7e06e7151c8e",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "animals = client.collections.get(\"Animals\")\n",
    "\n",
    "response = animals.query.near_text(\n",
    "    query=\"dog playing with stick\",\n",
    "    return_properties=['name','path','mediaType'],\n",
    "    limit=3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "654e1804-b1cd-4172-bcbf-a9c8f9feabd9",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "for obj in response.objects:\n",
    "    json_print(obj.properties)\n",
    "    display_media(obj.properties)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8581afdb",
   "metadata": {},
   "source": [
    "> Note: The output of the previous cell may differ from what is shown in the video. This variation is normal."
   ]
  },
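  {
   "cell_type": "markdown",
   "id": "cosine-ranking-sketch",
   "metadata": {},
   "source": [
    "Conceptually, `near_text` embeds the query and returns the stored objects whose vectors are closest to it. The sketch below illustrates that ranking step as a brute-force scan in NumPy. Everything in it is hypothetical: the 3-dimensional vectors, the file names, and the `top_k` helper are made up for illustration, and it assumes cosine similarity as the metric (cosine is Weaviate's default distance). The real embeddings in this notebook have 1408 dimensions, and the search runs inside Weaviate.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def top_k(query_vec, vectors, names, k=3):\n",
    "    # Normalize so that a dot product equals cosine similarity\n",
    "    q = query_vec / np.linalg.norm(query_vec)\n",
    "    v = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)\n",
    "    sims = v @ q\n",
    "    # Indices of the k highest similarities, best first\n",
    "    order = np.argsort(-sims)[:k]\n",
    "    return [(names[i], float(sims[i])) for i in order]\n",
    "\n",
    "# Made-up vectors standing in for stored image and video embeddings\n",
    "names = ['dog-stick.jpg', 'cat.jpg', 'dog-run.mp4', 'meerkat.jpg']\n",
    "vectors = np.array([\n",
    "    [0.9, 0.1, 0.0],\n",
    "    [0.0, 1.0, 0.1],\n",
    "    [0.8, 0.2, 0.1],\n",
    "    [0.1, 0.1, 1.0],\n",
    "])\n",
    "\n",
    "# Made-up vector standing in for the embedded text query\n",
    "query = np.array([1.0, 0.1, 0.0])\n",
    "\n",
    "for name, sim in top_k(query, vectors, names, k=2):\n",
    "    print(name, round(sim, 3))\n",
    "```\n",
    "\n",
    "With these made-up numbers, the two dog items rank first, one image and one video, because images and videos share a single vector space. Weaviate uses an approximate index rather than a full scan, so treat this as a conceptual sketch only."
   ]
  },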
  {
   "cell_type": "markdown",
   "id": "5c901ac2-ad7f-4b0f-bfc5-d4a73007b6ee",
   "metadata": {},
   "source": [
    "## Image to Media Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c66858c-d241-4e23-94ea-e7f9eb12df92",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# Use this image as an input for the query\n",
    "Image(\"./test/test-cat.jpg\", width=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c10059d-661c-4785-af72-9c1cacd77089",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "# The query\n",
    "response = animals.query.near_image(\n",
    "    near_image=file_to_base64(\"./test/test-cat.jpg\"),\n",
    "    return_properties=['name','path','mediaType'],\n",
    "    limit=3\n",
    ")\n",
    "\n",
    "for obj in response.objects:\n",
    "    json_print(obj.properties)\n",
    "    display_media(obj.properties)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1362103-71b8-4a5e-8667-7b6b3b913963",
   "metadata": {},
   "source": [
    "## Image search - from web URL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df941712-23f3-454c-aafa-b4ef21be83ff",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "Image(\"https://raw.githubusercontent.com/weaviate-tutorials/multimodal-workshop/main/2-multimodal/test/test-meerkat.jpg\", width=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd9b506f-9853-4399-9480-0af4b43f496a",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "# The query\n",
    "response = animals.query.near_image(\n",
    "    near_image=url_to_base64(\"https://raw.githubusercontent.com/weaviate-tutorials/multimodal-workshop/main/2-multimodal/test/test-meerkat.jpg\"),\n",
    "    return_properties=['name','path','mediaType'],\n",
    "    limit=3\n",
    ")\n",
    "\n",
    "for obj in response.objects:\n",
    "    json_print(obj.properties)\n",
    "    display_media(obj.properties)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b9fc064-8d33-4f75-88b2-3886e4d6624f",
   "metadata": {},
   "source": [
    "## Video to Media Search\n",
    "> Note: the input video must be at least 4 seconds long."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "424bb2ae-7849-4dd3-9996-b07fa5233d8e",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "Video(\"./test/test-meerkat.mp4\", width=400)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a18c53d9-47c9-4c01-be97-a6d84da9e19a",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "from weaviate.classes.query import NearMediaType\n",
    "\n",
    "response = animals.query.near_media(\n",
    "    media=file_to_base64(\"./test/test-meerkat.mp4\"),\n",
    "    media_type=NearMediaType.VIDEO,\n",
    "    return_properties=['name','path','mediaType'],\n",
    "    limit=3\n",
    ")\n",
    "\n",
    "for obj in response.objects:\n",
    "    # json_print(obj.properties)\n",
    "    display_media(obj.properties)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66597dcd-8b6c-4851-a502-ade363b33a6c",
   "metadata": {},
   "source": [
    "## Visualizing a Multimodal Vector Space\n",
    "\n",
    "> To make this more exciting, let's load up a larger dataset!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f130d647-ad90-4e3b-9cbd-6557091e6ae6",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "import pandas as pd\n",
    "import umap\n",
    "import umap.plot\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c7b73ef-930d-423d-a737-04fe5128beac",
   "metadata": {},
   "source": [
    "## Load vector embeddings and mediaType from Weaviate "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b71f487c",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "client.backup.restore(\n",
    "    backup_id=\"resources-img-and-vid\",\n",
    "    include_collections=\"Resources\",\n",
    "    backend=\"filesystem\"\n",
    ")\n",
    "\n",
    "# It can take a few seconds for the \"Resources\" collection to be ready.\n",
    "# We add 5 seconds of sleep to make sure it is ready for the next cells to use.\n",
    "import time\n",
    "time.sleep(5)"
   ]
  },
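  {
   "cell_type": "markdown",
   "id": "wait-until-sketch",
   "metadata": {},
   "source": [
    "The fixed five-second sleep is simple, but it can wait longer than needed or not long enough on a slow machine. One alternative is to poll until the collection shows up. The sketch below is an illustration under assumptions, not part of the course code: `wait_until` is a hypothetical helper, and the commented usage line relies on `client.collections.exists`, the same call used earlier in this notebook. A collection that exists may still be loading data, so this is a readiness hint rather than a guarantee.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "def wait_until(predicate, timeout=30.0, interval=0.5):\n",
    "    # Call predicate() repeatedly until it returns True or time runs out\n",
    "    deadline = time.monotonic() + timeout\n",
    "    while time.monotonic() < deadline:\n",
    "        if predicate():\n",
    "            return True\n",
    "        time.sleep(interval)\n",
    "    return False\n",
    "\n",
    "# Stand-in predicate that becomes true on the third call\n",
    "calls = {'n': 0}\n",
    "\n",
    "def ready():\n",
    "    calls['n'] += 1\n",
    "    return calls['n'] >= 3\n",
    "\n",
    "print(wait_until(ready, timeout=5.0, interval=0.01))\n",
    "\n",
    "# In this notebook the predicate could be, for example:\n",
    "# wait_until(lambda: client.collections.exists('Resources'))\n",
    "```\n",
    "\n",
    "The helper returns `False` on timeout instead of raising, so the caller decides whether to stop or continue."
   ]
  },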
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c8c118-4354-474d-aae1-972a2942b419",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "# Collection named \"Resources\"\n",
    "collection = client.collections.get(\"Resources\")\n",
    "\n",
    "embs = []\n",
    "labs = []\n",
    "for item in collection.iterator(include_vector=True):\n",
    "    # print(item.properties)\n",
    "    labs.append(item.properties['mediaType'])\n",
    "    embs.append(item.vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9acdee3e-4e5e-410a-bb2e-0ae1ed4013f4",
   "metadata": {
    "height": 132
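   },
   "outputs": [],
   "source": [
    "# item.vector is a dict keyed by vector name; keep the \"default\" vector\n",
    "embs2 = [emb['default'] for emb in embs]\n",
    "\n",
    "emb_df = pd.DataFrame(embs2)\n",
    "labels = pd.Series(labs)\n",
    "\n",
    "# Encode the media type as a number for plotting: image -> 0, video -> 1\n",
    "labels[labels=='image'] = 0\n",
    "labels[labels=='video'] = 1"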
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1b0780d",
   "metadata": {},
   "source": [
    "> Note: This step might take a few minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4566c0e9-8139-469f-9fa4-faa0e81841f2",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "mapper2 = umap.UMAP().fit(emb_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4ad14b4-d4c0-416b-b037-394c7a62b787",
   "metadata": {},
   "source": [
    "## Plot the embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "947f7d9a-0b50-4ec1-bfc1-847b7cb71d64",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 8))\n",
    "umap.plot.points(mapper2, labels=labels, theme='fire')\n",
    "\n",
    "# Add a title and axis labels, then show the plot\n",
    "plt.title('UMAP Visualization of Embedding Space')\n",
    "plt.xlabel('UMAP Dimension 1')\n",
    "plt.ylabel('UMAP Dimension 2')\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b4b65dc-aef5-4080-ba73-41c09567d77f",
   "metadata": {},
   "source": [
    "## Interactive plot of vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f5f4335",
   "metadata": {},
   "source": [
    "> Note: After you run the following cell, a set of tool buttons appears on the right-hand side of the plot. Use them for actions such as highlighting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ecdb01b-8759-4ce5-956f-5a7f8f1480ae",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "umap.plot.output_notebook()\n",
    "\n",
    "p = umap.plot.interactive(mapper2, labels=labels, theme='fire')\n",
    "\n",
    "umap.plot.show(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "966d5bb0-f84e-4d6e-982a-b8f0f2cbc662",
   "metadata": {},
   "source": [
    "## Close the connection to Weaviate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f09a92-02bf-4137-89f8-08793e30f46f",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "client.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8127b14c",
   "metadata": {},
   "source": [
    "### Try it yourself! \n",
    "\n",
    "Run any of the cells above with your own images or videos, from local files or URLs!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
